Tart – A plug-and-play Transformer module for task-agnostic reasoning

Large Language Models
Author

Maxime Labonne

Published

January 6, 2024

Tip

Tart combines the performance of fine-tuning with the ease-of-use of in-context learning. It is a general framework that leverages embeddings produced by LLMs with an improved task-agnostic reasoning module.

📝 Paper: https://arxiv.org/abs/2306.07536

💻 GitHub: https://github.com/HazyResearch/TART

Task Adaptation Techniques

Given an LLM and limited labeled data for a task, how does one adapt the model to the task? We care about the following properties:

  • Task-agnostic: We want to use the exact same model for different tasks.
  • Quality: Performance should be competitive with task-specific methods.
  • Data-scalable: The more data, the better the performance.

Existing methods:

  • In-context learning: The model adapts from examples placed in the input, with no parameter updates. It is task-agnostic, but it underperforms the other methods and is constrained by the context length.
  • Fine-tuning: Updates all the parameters. It is task-dependent.
  • Adapters: Additional set of parameters that are updated for a given task. Performance is competitive, but it is task-dependent too.

Representation vs. Reasoning

Why does ICL underperform other methods? The authors hypothesize that it is because either:

  1. LLMs cannot generate good representations for the specific task.
  2. LLMs cannot perform probabilistic inference or reasoning using these representations.

The researchers use linear probing, a method that involves training a task-specific linear classifier using the representations generated by the LLM, to evaluate the information content of the representations.
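As a rough illustration, a linear probe is a logistic regression classifier trained on top of frozen embeddings. The minimal plain-Python sketch below assumes the embeddings have already been extracted from the LLM. The toy vectors and the helper names `train_linear_probe` and `probe_accuracy` are hypothetical. Batch gradient descent is just one simple way to fit the classifier.

```python
import math
from typing import List, Sequence, Tuple

def sigmoid(z: float) -> float:
    # Numerically stable logistic function
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)

def train_linear_probe(
    features: Sequence[Sequence[float]],
    labels: Sequence[int],
    lr: float = 0.5,
    epochs: int = 500,
) -> Tuple[List[float], float]:
    """Fit a binary logistic regression (w, b) on frozen embeddings with batch gradient descent."""
    dim, n = len(features[0]), len(features)
    w, b = [0.0] * dim, 0.0
    for _ in range(epochs):
        grad_w, grad_b = [0.0] * dim, 0.0
        for x, y in zip(features, labels):
            # For the mean cross-entropy loss, d(loss)/d(logit) = p - y
            err = sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b) - y
            grad_w = [g + err * xi for g, xi in zip(grad_w, x)]
            grad_b += err
        w = [wi - lr * g / n for wi, g in zip(w, grad_w)]
        b -= lr * grad_b / n
    return w, b

def probe_accuracy(w, b, features, labels) -> float:
    preds = [int(sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b) >= 0.5) for x in features]
    return sum(p == y for p, y in zip(preds, labels)) / len(labels)

# Toy stand-ins for frozen LLM embeddings; only w and b are learned
train_x = [[1.0, 0.2], [0.9, -0.1], [-1.0, 0.3], [-0.8, -0.2]]
train_y = [1, 1, 0, 0]
w, b = train_linear_probe(train_x, train_y)
print(probe_accuracy(w, b, train_x, train_y))  # 1.0
```

The LLM is never updated. The probe's accuracy therefore reflects how much task-relevant information is linearly recoverable from its representations.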

They then decompose the performance gap between FT and ICL into two components:

  • Δrep, the gap due to insufficient representation capacity.
  • Δreas, the gap due to insufficient reasoning abilities.

Here, Acc_LR denotes the accuracy of the linear probe, a logistic regression classifier trained on the LLM's representations:

\begin{align}
\Delta_{\text{perf}} &= \text{Acc}_{\text{FT}} - \text{Acc}_{\text{ICL}} \\
&= (\text{Acc}_{\text{FT}} - \text{Acc}_{\text{LR}}) + (\text{Acc}_{\text{LR}} - \text{Acc}_{\text{ICL}}) \\
&= \Delta_{\text{rep}} + \Delta_{\text{reas}}
\end{align}
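The decomposition is a simple add-and-subtract of the linear probe's accuracy. The sketch below plugs in made-up accuracies, not figures from the paper, to show that the two terms always sum back to the total gap. `decompose_gap` is a hypothetical helper.

```python
def decompose_gap(acc_ft: float, acc_lr: float, acc_icl: float) -> dict:
    """Split the FT-ICL accuracy gap into representation and reasoning terms."""
    delta_rep = acc_ft - acc_lr    # what a linear probe on the LLM's embeddings still misses
    delta_reas = acc_lr - acc_icl  # what ICL misses despite having the same representations
    return {"perf": acc_ft - acc_icl, "rep": delta_rep, "reas": delta_reas}

# Hypothetical accuracies, chosen only to illustrate the arithmetic
gaps = decompose_gap(acc_ft=0.90, acc_lr=0.86, acc_icl=0.70)
print({k: round(v, 2) for k, v in gaps.items()})
# {'perf': 0.2, 'rep': 0.04, 'reas': 0.16}
```

In this illustrative case, 0.16 of the 0.20 gap (80%) would be attributed to reasoning rather than representation.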

Using the experiments from the previous figure, they make the following observations:

  1. LLMs lack reasoning abilities, despite having sufficient information in their representations (Figure a).
  2. Fine-tuning improves task-specific reasoning, accounting for 73.06% of the improvements (Figure b).
  3. Adapters only learn task-specific reasoning abilities.

Tart: Task-Agnostic Reasoning Transformers

Tart learns an LLM-agnostic and task-agnostic reasoning module without any task-specific training. It has two components:

  1. A generic task-agnostic reasoning module, trained on synthetic data (Gaussian logistic regression problems) with the objective of performing probabilistic inference.
  2. Embeddings from the base LLM, aggregated and used as input along with the class labels.

Note

In Gaussian logistic regression, feature vectors drawn from a Gaussian distribution are mapped to discrete binary labels.

1. Reasoning module

This module is based on GPT-2 and is trained autoregressively to predict the next item in a sequence.

We denote the input sequence with labeled examples (x_1, y_1), (x_2, y_2), \dots , (x_k, y_k), with each example z_i = (x_i, y_i) using only two tokens (one for x, one for y). In comparison, standard LLMs would use multiple tokens to encode x, limiting the number of samples in the context window.
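Here is a minimal sketch of the two-tokens-per-example layout. It interleaves feature vectors and label vectors into one sequence of 2k input vectors. It assumes the 16-dimensional input space from the hyperparameters paragraph and encodes label y as the basis vector e_y. That exact encoding and the helper names are assumptions for illustration.

```python
from typing import List, Sequence

def one_hot_label(y: int, dim: int) -> List[float]:
    # Assumed encoding: label y in {0, 1} becomes the basis vector e_y in R^dim
    vec = [0.0] * dim
    vec[y] = 1.0
    return vec

def build_sequence(xs: Sequence[Sequence[float]], ys: Sequence[int], dim: int = 16) -> List[List[float]]:
    """Interleave examples as [x_1, y_1, x_2, y_2, ..., x_k, y_k], one vector per token."""
    if len(xs) != len(ys):
        raise ValueError("need one label per feature vector")
    seq: List[List[float]] = []
    for x, y in zip(xs, ys):
        if len(x) != dim:
            raise ValueError(f"expected {dim}-dimensional features, got {len(x)}")
        seq.append(list(x))                # token for x_i
        seq.append(one_hot_label(y, dim))  # token for y_i
    return seq

k = 3
xs = [[0.1 * (i + j) for j in range(16)] for i in range(k)]
ys = [1, 0, 1]
seq = build_sequence(xs, ys)
print(len(seq))  # 6 tokens for 3 examples
```

A text prompt would instead spend many tokens on each x_i. For a fixed context length, this layout therefore fits far more examples.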

The module is trained with gradient descent to minimize a cross-entropy loss. Each training sequence represents a different logistic regression problem, with parameters and features drawn from standard normal distributions.
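The sketch below assumes that, after each x_i token, the module outputs a probability p_i that y_i = 1 given the preceding pairs. Under that assumption, the per-sequence cross-entropy that gradient descent minimizes could look like this. The output parameterization may not match the paper's exact formulation.

```python
import math
from typing import Sequence

def sequence_cross_entropy(probs: Sequence[float], labels: Sequence[int], eps: float = 1e-12) -> float:
    """Mean binary cross-entropy over the k label positions of one sequence.

    probs[i] is the (assumed) predicted P(y_i = 1 | x_1, y_1, ..., x_i).
    """
    total = 0.0
    for p, y in zip(probs, labels):
        p = min(max(p, eps), 1.0 - eps)  # clip to avoid log(0)
        total += -(y * math.log(p) + (1 - y) * math.log(1.0 - p))
    return total / len(labels)

# Uninformed 50/50 guesses cost log 2 per label
print(round(sequence_cross_entropy([0.5, 0.5], [1, 0]), 4))  # 0.6931
```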

Each label y_i is derived from a sigmoid function applied to the dot product of the feature vector and the parameters. The dot product is scaled by a factor α that controls the noise level.
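One synthetic training sequence could be generated as in this sketch. Features and parameters come from standard normal distributions, as described. Two parts are assumptions: sampling each binary label from a Bernoulli distribution with the sigmoid output as its probability, and the helper name `sample_task`.

```python
import math
import random
from typing import List, Tuple

def stable_sigmoid(z: float) -> float:
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)

def sample_task(k: int, dim: int = 16, alpha: float = 1.0, seed: int = 0) -> Tuple[List[List[float]], List[int]]:
    """Draw one synthetic logistic regression problem with k labeled examples (hypothetical helper)."""
    rng = random.Random(seed)
    w = [rng.gauss(0.0, 1.0) for _ in range(dim)]  # task parameters ~ N(0, I)
    xs: List[List[float]] = []
    ys: List[int] = []
    for _ in range(k):
        x = [rng.gauss(0.0, 1.0) for _ in range(dim)]  # features ~ N(0, I)
        p = stable_sigmoid(alpha * sum(wi * xi for wi, xi in zip(w, x)))  # alpha scales the logit
        ys.append(1 if rng.random() < p else 0)  # assumption: label ~ Bernoulli(p)
        xs.append(x)
    return xs, ys

# Each seed yields a different problem, i.e. a different training sequence
xs, ys = sample_task(k=8, alpha=2.0, seed=1)
print(len(xs), len(xs[0]), len(ys))  # 8 16 8
```

Because α multiplies the logit, a large α pushes p toward 0 or 1 and makes labels nearly deterministic given x. A small α makes them closer to coin flips.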

The reasoning module uses a 16-dimensional input space, and labels are one-hot encoded into this same space. Interestingly, the paper fits the LLM's high-dimensional output embeddings into this space by reducing them with PCA. The PCA is fitted only on the training points available for the task at hand.
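PCA is fitted on the task's training embeddings only, and the same projection is then applied to every point. To stay dependency-free, this sketch computes principal directions with power iteration and deflation. In practice, a library routine such as scikit-learn's `PCA` would be the natural choice. The helper names, the toy 4-dimensional embeddings, and the 2 output components (in place of 16) are illustrative assumptions.

```python
import random
from typing import List, Sequence, Tuple

def fit_pca(train: Sequence[Sequence[float]], n_components: int, iters: int = 200, seed: int = 0
            ) -> Tuple[List[float], List[List[float]]]:
    """Fit PCA on the task's training embeddings only, via power iteration with deflation."""
    n, d = len(train), len(train[0])
    mean = [sum(row[j] for row in train) / n for j in range(d)]
    centered = [[row[j] - mean[j] for j in range(d)] for row in train]
    # d x d sample covariance: fine for a toy, far too slow for real LLM embedding sizes
    cov = [[sum(r[a] * r[b] for r in centered) / max(n - 1, 1) for b in range(d)] for a in range(d)]
    rng = random.Random(seed)
    components: List[List[float]] = []
    for _ in range(n_components):
        v = [rng.gauss(0.0, 1.0) for _ in range(d)]
        for _ in range(iters):
            mv = [sum(cov[a][b] * v[b] for b in range(d)) for a in range(d)]
            norm = sum(x * x for x in mv) ** 0.5
            if norm == 0.0:
                break
            v = [x / norm for x in mv]
        # Rayleigh quotient = variance along v; deflate so the next pass finds the next direction
        lam = sum(v[a] * sum(cov[a][b] * v[b] for b in range(d)) for a in range(d))
        components.append(v)
        cov = [[cov[a][b] - lam * v[a] * v[b] for b in range(d)] for a in range(d)]
    return mean, components

def project(x: Sequence[float], mean: Sequence[float], components: List[List[float]]) -> List[float]:
    """Apply the fitted projection to any point (training or test)."""
    centered = [xi - mi for xi, mi in zip(x, mean)]
    return [sum(c * ci for c, ci in zip(comp, centered)) for comp in components]

rng = random.Random(42)
# Toy stand-in for LLM embeddings: 20 training points in 4 dims, most variance on axis 0
train_emb = [[rng.gauss(0.0, 5.0)] + [rng.gauss(0.0, 0.1) for _ in range(3)] for _ in range(20)]
mean, comps = fit_pca(train_emb, n_components=2)
test_point = [1.0, 0.0, 0.0, 0.0]
print(len(project(test_point, mean, comps)))  # 2
```

Fitting on the training points only means no test information leaks into the projection. Test points are simply projected with the same mean and components.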

The authors evaluated the reasoning module on 64 different logistic regression problems, comparing it with task-specific logistic regression classifiers (Figure a). The module's error decreased as the number of examples increased and stayed within 2% of the task-specific logistic regression classifier.

The noise level α of a logistic regression problem indicates its difficulty: lower values of α correspond to harder problems. The reasoning module handled easier problems (higher α) without any drop in accuracy. However, its error increased progressively as the problems became harder (Figure b).

2. Embeddings

The authors compared two types of embeddings:

  • Vanilla embeddings: All training examples are used, and their embedding vectors are averaged. However, performance degrades when there are too many samples in the context window.
  • Leave-One-Out (LOO): The embedding for each training point is created by placing all other training examples before it in the prompt and averaging the embeddings over the final example’s tokens.
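The Leave-One-Out construction can be sketched without a real model, under three assumptions made only to show how the prompt is assembled and which hidden states get averaged:

  • `fake_hidden_states` is a hypothetical stand-in for an LLM forward pass that returns one vector per token.
  • Whitespace splitting stands in for a tokenizer.
  • Examples are concatenated without labels or separators.

```python
from typing import Callable, List, Sequence

Vector = List[float]

def loo_embeddings(
    examples: Sequence[str],
    hidden_states: Callable[[List[str]], List[Vector]],
) -> List[Vector]:
    """Leave-one-out embeddings: example i is placed last, after all other examples."""
    out: List[Vector] = []
    for i, target in enumerate(examples):
        context = [tok for j, ex in enumerate(examples) if j != i for tok in ex.split()]
        target_tokens = target.split()  # whitespace split stands in for a real tokenizer
        if not target_tokens:
            raise ValueError(f"example {i} has no tokens")
        states = hidden_states(context + target_tokens)
        final = states[len(context):]  # hidden states of the final example's tokens only
        dim = len(final[0])
        out.append([sum(s[d] for s in final) / len(final) for d in range(dim)])
    return out

def fake_hidden_states(tokens: List[str]) -> List[Vector]:
    # Hypothetical stand-in for an LLM forward pass: (position, token length) per token
    return [[float(pos), float(len(tok))] for pos, tok in enumerate(tokens)]

examples = ["great movie", "terrible plot twist", "fine"]
embs = loo_embeddings(examples, fake_hidden_states)
print(embs[2])  # [5.0, 4.0]
```

In the toy output, every example's tokens end at the same position, since each one is placed last after all the others.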

Experiments

Tart was validated on various binary classification tasks in NLP, vision, and audio, using several base models, including GPT-Neo (125M).

Averaged over all tasks and model families, Tart improves upon base in-context learning performance by 18.4 points, improves upon adapter heads by 3.4 points, and is within 3.1 points of full fine-tuning.